from typing import Tuple
import jax
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params, PRNGKey


def pagar_update(key: PRNGKey, actor: Model, reward: Model, value: Model,
                 discount: float, batch: Batch, temperature: float,
                 double: bool) -> Tuple[Model, InfoDict]:
    """Apply one gradient step to the actor with a weighted log-likelihood loss.

    Each transition in ``batch`` is weighted by
    ``min(exp(temperature * q), 100.0)`` where
    ``q = r + discount * masks * value(next_observations)`` and ``r`` comes
    from the ``reward`` model. The loss is the negated mean of
    ``weight * log_prob(action)`` under the actor's distribution.

    Args:
        key: PRNG key passed to the actor as its ``'dropout'`` rng.
        actor: Policy model; provides ``apply`` and ``apply_gradient``.
        reward: Model called on (observations, actions, masked next
            observations); it is unpacked into two outputs.
        value: Model evaluated on ``batch.next_observations``.
        discount: Multiplier applied to the masked next-state value.
        batch: Provides ``observations``, ``actions``, ``next_observations``
            and ``masks``.
        temperature: Scale applied to ``q`` before exponentiation.
        double: When true, the elementwise minimum of the two reward outputs
            is used; otherwise only the first output is used.

    Returns:
        The actor returned by ``actor.apply_gradient`` and the info dict
        produced by the loss function.
    """
    # Next observations are multiplied by the mask (reshaped to a column)
    # before being handed to the reward model.
    # NOTE(review): presumably masks are 0 on terminal transitions - confirm.
    masked_next_obs = batch.next_observations * batch.masks.reshape(-1, 1)
    reward_a, reward_b = reward(batch.observations, batch.actions,
                                masked_next_obs)
    reward_est = jnp.minimum(reward_a, reward_b) if double else reward_a

    bootstrap = discount * batch.masks * value(batch.next_observations)
    target = reward_est + bootstrap

    # No value baseline is subtracted: the target itself is exponentiated.
    # The weights are capped at 100.0 and are computed outside the loss
    # function, so they do not depend on the actor parameters.
    weights = jnp.minimum(jnp.exp(target * temperature), 100.0)

    def weighted_nll(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        policy_dist = actor.apply({'params': actor_params},
                                  batch.observations,
                                  training=True,
                                  rngs={'dropout': key})
        weighted_log_probs = weights * policy_dist.log_prob(batch.actions)
        loss = -weighted_log_probs.mean()
        # The 'adv' entry carries the reward estimate, not the target.
        return loss, {'actor_loss': loss, 'adv': reward_est}

    updated_actor, info = actor.apply_gradient(weighted_nll)
    return updated_actor, info
